#!/usr/bin/env python
"""
Quick Plot - QUOKKA simulation visualization tool based on yt

This script provides a command-line interface for creating slice plots and projection plots
from QUOKKA simulation data using the yt library. It's designed to work with AMReX/BoxLib
format simulation outputs, particularly from the QUOKKA radiation-hydrodynamics code.

MAIN FEATURES:
=============
- Create slice plots or projection plots of various physical fields
- Support for multiple field types: density, temperature, velocity components, number density, momentum density
- Particle annotation with customizable markers, sizes, and colors  
- Parallel processing for batch plotting multiple snapshots
- Time-based snapshot filtering to create evenly spaced sequences
- Grid and cell edge annotations for mesh visualization
- Customizable plot parameters: width, zoom limits, colormaps, figure size
- Custom text annotations with LaTeX support

SUPPORTED FIELDS:
================
- Density: 'density', 'rho', 'den' -> ("gas", "density")
- Number density: 'n', 'nH', 'n_H', 'num_density' -> ("gas", "number_density") 
- Temperature: 'temperature', 'T', 'temp' -> ("gas", "temperature")
- Velocity components: 'vx', 'vy', 'vz' -> ("gas", "velocity_x/y/z")
- Velocity magnitude: 'v', 'velocity' -> ("gas", "velocity")
- Momentum density: 'p', 'momentum' -> ("gas", "momentum_density")
- Boxlib fields: any other string -> ("boxlib", field_name)

PLOT TYPES:
===========
- Slice plots ('slc', 'slice'): 2D cross-sections through 3D data
- Projection plots ('prj', 'proj', 'projection'): Line-of-sight integrated quantities

USAGE EXAMPLES:
===============
Basic density slice plot (the following commands are equivalent)
    ./quick_plot plt00001 
    ./quick_plot plt00001 -f rho
    ./quick_plot plt00001 --field rho

Temperature projection of all snapshots with custom width
    ./quick_plot plt* -f T --kind proj --width 10_kpc

Batch processing with particles and custom output
    ./quick_plot plt00* -f rho --particles CIC_particles --outdir figures -j 4

Time-filtered sequence with annotations. This is helpful when the output is too dense or is not regularly spaced in time.
The '--grids' adds AMReX box boundaries to the plot. The '--top_left_text' adds text to the top-left corner of the plot.
    ./quick_plot plt* --time_interval 1_Myr --grids --top_left_text "Simulation X"

Advanced customization:
    ./quick_plot plt* --field T --cmap hot --figsize 8 --zlim 1e3 1e6 --p_size 200 --p_marker "*"

A comprehensive example:
    ./quick_plot plt* --outdir figures -f T --kind proj --width 10_kpc --zlim 1e3 1e6 --p_size 200 --p_marker "*" --p_color "red" --grids --cell_edges --top_left_text "Simulation X" --time_interval 1_Myr --center 1e10,1e10,1e10 -j 4

KEY FUNCTIONS:
==============
- plot_one(pltdir, args): Creates a single plot for one snapshot directory, takes args namespace directly
- filter_snapshots_by_time_interval(): Filters snapshots to create evenly spaced time sequences
- main(args): Main orchestration function for batch processing, takes args namespace directly
- parse_args(): Command-line argument parsing with extensive options
- if __name__ == '__main__': script entry point; calls parse_args() and passes the result to main(args)

TECHNICAL DETAILS:
==================
- Requires yt installed from: https://github.com/chongchonghe/yt
- Supports AMReX/BoxLib simulation formats
- Creates derived temperature and number-density fields assuming an adiabatic index gamma = 5/3 and a mean molecular weight mu (default 1, in atomic mass units m_u; set it with '--mean_molecular_weight').
- Automatically filters out .old. directories
- Supports both single-threaded and multi-process execution
- Creates high-resolution output (300 DPI) with tight bounding boxes and padding

PARTICLE SUPPORT:
=================
- Annotates particle positions from multiple particle types simultaneously
- Customizable particle appearance (size, marker style, color)
- Automatic color cycling for multiple particle types, with default colors
- Depth-based filtering to avoid cluttered visualization
- Handles missing or empty particle datasets gracefully
- Supports multiple particle types simultaneously

OUTPUT:
=======
- Saves plots in PNG format with descriptive filenames
- Optional custom output directory creation
- Supports skipping existing files for incremental processing
"""

import os
import argparse
import numpy as np
from multiprocessing import Pool, cpu_count
import pprint
import yt
import unyt
import matplotlib.pyplot as plt

# Optional plot styling: importing scienceplots registers the "science" and
# "nature" matplotlib styles. If it is missing or the styles fail to apply,
# the default matplotlib style stays in effect.
try:
    import scienceplots  # noqa: F401  (imported for its style registration)
    plt.style.use(["science", "nature", "no-latex"])
    print("scienceplots loaded successfully")
except ImportError:
    print("scienceplots not installed; using default matplotlib style")
except Exception as style_err:
    print(f"scienceplots installed but failed to load style: {style_err}; using default matplotlib style")

def _version_tuple(version):
    """Return the leading numeric release components of a version string.

    e.g. "4.3.0" -> (4, 3, 0), "4.10.1" -> (4, 10, 1), "4.4.dev0" -> (4, 4).
    Parsing stops at the first dot-separated piece that does not start with a
    digit; trailing non-digits inside a piece (e.g. "0rc1") are ignored.
    """
    import re

    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


# Check the yt version numerically: a plain string comparison would reject
# e.g. "4.10.0", because "1" sorts before "3". Raise instead of assert so the
# check is not stripped when running under `python -O`.
if _version_tuple(yt.__version__) < (4, 3, 0):
    raise ImportError(f"yt version must be >= 4.3.0, found {yt.__version__}")

yt.set_log_level(40)

# atomic mass unit (g); mean molecular weights are given in these units
m_u = 1.660539e-24 * unyt.g


def _snapshot_index(pltdir):
    """Return the integer at the end of a plotfile directory name, or None.

    Trailing path separators are ignored, so both "plt00010" and "plt00010/"
    give 10. A name that does not end in digits gives None.
    """
    name = os.path.basename(os.path.normpath(pltdir))
    n_digits = len(name) - len(name.rstrip("0123456789"))
    return int(name[len(name) - n_digits:]) if n_digits else None


def _annotate_particles(slc, ds, ad, pltdir, particle_types, field, p_size, p_marker, p_color):
    """Overlay the positions of each requested particle type on the plot `slc`.

    Particle types that are absent from the plotfile, or that have no
    particles, are reported and skipped. Markers are drawn for particles
    within a depth of 0.1 times the domain width along the first axis. Colors
    cycle through a default palette unless `p_color` is given.
    """
    if 'particles' not in ds.parameters:
        print(f"no particles in ds.parameters, {pltdir}")
        print(ds.parameters)
        return

    colors = ['red', 'magenta', 'cyan', 'yellow',
              'lime', 'hotpink', 'orange', 'deepskyblue']
    if field == ("gas", "temperature"):
        # presumably chosen to stand out against the "hot" temperature colormap
        colors = ['green', 'lime', 'cyan',
                  'hotpink', 'orange', 'deepskyblue']

    # domain width along the first axis
    Lx = ds.domain_right_edge[0] - ds.domain_left_edge[0]
    particle_info = ds['particle_info']
    for i, par in enumerate(particle_types):
        if par not in particle_info:
            print("particle ", par, " not found in ", pltdir)
            continue
        if particle_info[par]['num_particles'] == 0:
            print(f"no {par} particles to annotate in {pltdir}")
            continue

        # get particle position
        try:
            pos = ad[(par, "particle_position_x")]
        except yt.utilities.exceptions.YTFieldNotFound:
            print(f"no {par} particles to annotate in {pltdir}")
            continue
        if len(pos) == 0:
            print("no particles to annotate in ", pltdir)
            print("pos =", pos)
            continue

        # cycle through the palette so more particle types than colors still work
        color = p_color if p_color is not None else colors[i % len(colors)]
        # annotate particles at a depth of 0.1 * boxsize
        slc.annotate_particles(
            Lx * 0.1, p_size=p_size, col=color, marker=p_marker, ptype=par)


def plot_one(pltdir, args):
    """Create and save one slice or projection plot for a single plotfile.

    Parameters
    ----------
    pltdir : str
        Path to an AMReX/BoxLib plotfile directory, loaded with ``yt.load``.
    args : argparse.Namespace
        Options from ``parse_args``. ``args.field`` must already be a yt
        field tuple and ``args.center`` must already be parsed; ``main`` does both.

    The figure is written into ``args.outdir`` under the file name chosen by
    ``slc.save()``. Returns None, also when an existing figure is skipped.

    Raises
    ------
    ValueError
        If ``args.kind`` is not one of slc/slice/prj/proj/projection.
    """
    print(f"processing {pltdir}")

    # Snapshot number (trailing digits of the directory name). It is only used
    # to special-case the initial snapshot when choosing color limits.
    idx = _snapshot_index(pltdir)

    ds = yt.load(pltdir)
    ad = ds.all_data()

    # Extract parameters from args
    field = args.field
    kind = args.kind
    width = args.width
    zlim = args.zlim
    particle = args.particles
    grids = args.grids
    cell_edges = args.cell_edges
    time_off = args.timeoff
    view_dir = args.dir
    center = args.center
    top_left_text = args.top_left_text
    # getattr keeps hand-built Namespaces without this option working
    top_right_text = getattr(args, "top_right_text", None)
    mean_molecular_weight = args.mean_molecular_weight
    figsize = args.figsize
    cmap = args.cmap
    p_size = args.p_size
    p_marker = args.p_marker
    p_color = args.p_color
    skip_existing = args.skip_existing
    annotate_center = args.annotate_center
    axis_unit = args.axis_unit
    outdir = args.outdir
    hide_all = args.hide_all
    hide_axes = args.hide_axes

    # Normalize the plot kind first, so an unsupported kind fails fast and the
    # output name is known before any expensive plotting work.
    if kind in ["slc", "slice"]:
        kind = "slc"
    elif kind in ["prj", "proj", "projection"]:
        kind = "prj"
    else:
        raise ValueError(f"kind {kind} not supported")

    field_root = field if not isinstance(field, tuple) else field[1]
    field_rho = ("gas", "density")  # weight field for projections

    # Guess the output file name (e.g. plt00008_Slice_z_density.png) and skip
    # this snapshot if it already exists.
    # NOTE(review): yt may append the weight field to projection file names
    # (e.g. ..._Projection_z_temperature_density.png), so both spellings are
    # checked for projections.
    fn_slc = {"slc": "Slice", "prj": "Projection"}[kind]
    fig_stem = f"{ds.basename}_{fn_slc}_{view_dir}_{field_root}"
    candidates = [f"{fig_stem}.png"]
    if kind == "prj":
        candidates.append(f"{fig_stem}_{field_rho[1]}.png")
    if skip_existing:
        for fig_name in candidates:
            if os.path.exists(os.path.join(outdir, fig_name)):
                print(f"skipping existing figure: {fig_name}")
                return

    mean_molecular_weight_per_H_atom = mean_molecular_weight * m_u

    # Add derived fields.
    # NOTE(review): these register fields under names yt may already define
    # (e.g. ("gas", "number_density")). If so, add_field may not replace the
    # existing definition; confirm which one is used.
    if field == ("gas", "number_density"):
        # n = rho / (mu * m_u)
        ds.add_field(field, function=lambda field, data: data[(
            "gas", "density")] / mean_molecular_weight_per_H_atom, units="cm**-3", sampling_type="cell")
    elif field == ("gas", "temperature"):
        # T = (gamma - 1) * e_int / (n * k_B), with n = rho / (mu * m_u)
        k_B = unyt.physical_constants.boltzmann_constant
        gamma = 5.0 / 3.0
        ds.add_field(field, function=lambda field, data: data[("gas", "internal_energy_density")] * (gamma - 1.0) / (
            data[('gas', 'density')] / mean_molecular_weight_per_H_atom * k_B), units="K", sampling_type="cell")
    elif field == ("gas", "velocity"):
        # velocity magnitude
        ds.add_field(field, function=lambda field, data: np.sqrt(data[("gas", "velocity_x")]**2 + data[(
            "gas", "velocity_y")]**2 + data[("gas", "velocity_z")]**2), units="cm/s", sampling_type="cell")
    elif field == ("gas", "momentum_density"):
        # momentum density magnitude
        ds.add_field(field, function=lambda field, data: np.sqrt(data[("gas", "momentum_density_x")]**2 + data[(
            "gas", "momentum_density_y")]**2 + data[("gas", "momentum_density_z")]**2), sampling_type="cell")

    # plot slice or density-weighted projection
    if kind == "slc":
        slc = yt.SlicePlot(ds, view_dir, field, center=center)
    else:
        slc = yt.ProjectionPlot(ds, view_dir, field,
                                weight_field=field_rho, center=center)

    if cmap == "default":
        cmap = "hot" if field == ("gas", "temperature") else "viridis"
    slc.set_cmap(field, cmap)

    slc.set_background_color(field, 'black')
    if width is not None:
        # "1,kpc" or "1_kpc" -> (value, unit); a bare number is passed as a float
        if ',' in width:
            w = (float(width.split(',')[0]), width.split(',')[1])
        elif '_' in width:
            w = (float(width.split('_')[0]), width.split('_')[1])
        else:
            w = float(width)
        slc.set_width(w)
    if axis_unit is not None:
        slc.set_axes_unit(axis_unit)
        print(f"set axes unit to {axis_unit}")
    if zlim is not None:
        zlim0 = 'min' if zlim[0].lower() == "min" else float(zlim[0])
        zlim1 = 'max' if zlim[1].lower() == "max" else float(zlim[1])
        # Special case for snapshot 0, which usually has a nearly uniform field:
        # widen a degenerate automatic range so the colorbar is usable.
        if idx == 0 and zlim1 == 'max':
            if zlim0 == 'min':
                den_min = ad.min(field)
                den_max = ad.max(field)
                if np.isclose(den_min, den_max):
                    zlim0 = 0.9 * den_min
                    zlim1 = 1.1 * den_max
            else:
                zlim1 = 1.1 * zlim0
        slc.set_zlim(field, zlim0, zlim1)
    if len(particle) > 0:
        _annotate_particles(slc, ds, ad, pltdir, particle, field,
                            p_size, p_marker, p_color)
    if grids:
        slc.annotate_grids(edgecolors='white', linewidth=1)
    if cell_edges:
        slc.annotate_cell_edges(line_width=0.001, color='black')
    if not time_off:
        slc.annotate_timestamp()
    if top_left_text is not None:
        # Add text annotation at the top-left corner with LaTeX support
        slc.annotate_text((0.02, 0.98), top_left_text, coord_system='axis', text_args={'color': 'white', 'usetex': True, 'verticalalignment': 'top', 'horizontalalignment': 'left'})
    if top_right_text is not None:
        # Add text annotation at the top-right corner with LaTeX support
        slc.annotate_text((0.98, 0.98), top_right_text, coord_system='axis', text_args={'color': 'white', 'usetex': True, 'verticalalignment': 'top', 'horizontalalignment': 'right'})
    if annotate_center is not None:
        # Add text annotation at the center of the domain
        slc.annotate_text((0.5, 0.5), annotate_center, coord_system='axis', text_args={'color': 'red', 'usetex': True, 'verticalalignment': 'center', 'horizontalalignment': 'center'})

    slc.set_figure_size(figsize)

    if hide_all:
        slc.hide_axes(draw_frame=True)
        slc.set_colorbar_label(field, "")
        slc.hide_colorbar()
    elif hide_axes:
        slc.hide_axes(draw_frame=True)

    # slc.save() is called without a path, so it writes relative to the
    # current directory. Switch to outdir temporarily and always switch back,
    # even if saving fails.
    cwd = os.getcwd()
    os.chdir(outdir)
    try:
        fn = slc.save(
            mpl_kwargs={"dpi": 300, "bbox_inches": "tight", "pad_inches": 0.1})
    finally:
        os.chdir(cwd)
    print(f"{fn} saved")


def filter_snapshots_by_time_interval(pltdirs, time_interval):
    """Select approximately evenly spaced snapshots from `pltdirs`.

    Snapshots are sorted by time. A snapshot is kept when its time is at or
    after the current target n * time_interval (starting with n = 0). After a
    snapshot is kept, the next target is the first multiple of time_interval
    strictly later than that snapshot's time. As a result, a gap in the
    outputs does not cause a burst of consecutive snapshots to be selected.

    Parameters
    ----------
    pltdirs : list of str
        Plotfile directories to consider.
    time_interval : str or None
        Spacing, either "1" (in the units of ``ds.current_time``) or
        "<value>_<unit>", e.g. "0.1_Myr". None disables filtering.

    Returns
    -------
    tuple
        ``(selected_pltdirs, selected_times, time_unit)``. When time_interval
        is None this is ``(pltdirs, None, None)``; when no snapshot could be
        loaded it is ``([], None, None)``.

    Raises
    ------
    ValueError
        If time_interval is not a positive number.
    """
    if time_interval is None:
        return pltdirs, None, None

    if '_' in time_interval:
        value, time_unit = time_interval.split('_')[:2]
        time_interval = float(value)
    else:
        time_interval = float(time_interval)
        time_unit = None
    # "not >" also rejects NaN; a non-positive interval would select everything
    if not time_interval > 0:
        raise ValueError(f"time_interval must be positive, got {time_interval}")

    # Get times for all snapshots
    snapshot_times = []
    for pltdir in pltdirs:
        try:
            ds = yt.load(pltdir)
            snapshot_times.append(
                (pltdir, ds.current_time.to_value(time_unit)))
        except Exception as e:
            # best effort: skip unreadable snapshots instead of aborting the batch
            print(f"Warning: Could not load {pltdir}: {e}")

    if not snapshot_times:
        return [], None, None

    snapshot_times.sort(key=lambda item: item[1])

    filtered_pltdirs = []
    filtered_times = []
    next_index = 0
    for pltdir, time in snapshot_times:
        if time >= next_index * time_interval:
            filtered_pltdirs.append(pltdir)
            filtered_times.append(time)
            # Jump to the first multiple of time_interval strictly after
            # `time`; the while loop corrects floor-division rounding.
            next_index = max(next_index + 1, int(time // time_interval) + 1)
            while next_index * time_interval <= time:
                next_index += 1

    return filtered_pltdirs, filtered_times, time_unit


def test_filter_snapshots_by_time_interval():
    """Check the time-interval selection rule on synthetic snapshot times.

    A time is kept when it is at or after the current target
    n * time_interval. The next target is then the first multiple of
    time_interval strictly after the kept time, so the gap between 1.9 and 4.1
    does not select a burst (4.1, 4.2, 4.9). Prints the input and the selected
    times, then asserts the expected selection.
    """
    times = sorted([0.0, 1.1, 1.9, 4.1, 4.2, 4.9, 5.1, 5.9, 6.1, 7.3])
    time_interval = 1.0

    filtered_times = []
    next_index = 0
    for time in times:
        if time >= next_index * time_interval:
            filtered_times.append(time)
            next_index = max(next_index + 1, int(time // time_interval) + 1)
            while next_index * time_interval <= time:
                next_index += 1

    print("Original times:")
    print(times)
    print("Filtered times:")
    print(filtered_times)
    assert filtered_times == [0.0, 1.1, 4.1, 5.1, 6.1, 7.3], filtered_times


# Aliases accepted by --field, grouped by the yt field tuple they map to.
_FIELD_ALIAS_GROUPS = (
    (("density", "rho", "den"), ("gas", "density")),
    (("n", "nH", "n_H", "num_density"), ("gas", "number_density")),
    (("temperature", "T", "temp"), ("gas", "temperature")),
    (("vx", "velocity-x", "velocity_x"), ("gas", "velocity_x")),
    (("vy", "velocity-y", "velocity_y"), ("gas", "velocity_y")),
    (("vz", "velocity-z", "velocity_z"), ("gas", "velocity_z")),
    (("v", "velocity"), ("gas", "velocity")),
    (("p", "momentum"), ("gas", "momentum_density")),
)
_FIELD_ALIASES = {
    alias: yt_field
    for aliases, yt_field in _FIELD_ALIAS_GROUPS
    for alias in aliases
}


def main(args):
    """Plot every requested snapshot according to the parsed options.

    Normalizes ``args`` in place: ``args.field`` becomes a yt field tuple
    (unknown names map to ``("boxlib", name)``), and ``args.center`` becomes
    a tuple of floats, a float, or ``'c'`` (domain center). Then calls
    ``plot_one`` for each selected snapshot, serially or in a process pool.

    Raises
    ------
    ValueError
        If ``args.dir`` is not "x", "y" or "z".
    """
    if args.print_field_list:
        ds = yt.load(args.pltdirs[0])
        print("ds.derived_field_list:")
        pprint.pprint(ds.derived_field_list)
        return

    # raise rather than assert: asserts are stripped under `python -O`
    if args.dir not in ["x", "y", "z"]:
        raise ValueError(f"--dir must be one of x, y, z; got {args.dir!r}")

    if args.outdir != ".":
        os.makedirs(args.outdir, exist_ok=True)

    # Map field names to yt field tuples
    if args.field is None:
        args.field = ("gas", "density")
    else:
        args.field = _FIELD_ALIASES.get(args.field, ("boxlib", args.field))

    # Filter out invalid directories and old directories
    valid_pltdirs = [
        pltdir for pltdir in args.pltdirs if os.path.isdir(pltdir) and ".old." not in os.path.basename(pltdir)
    ]

    # Filter snapshots by time interval if specified
    if args.time_interval is not None:
        print(f"Filtering snapshots by time interval of {args.time_interval} ...")
        valid_pltdirs, filtered_times, time_unit = filter_snapshots_by_time_interval(
            valid_pltdirs, args.time_interval)
        print(f"Filtered times: {filtered_times} {time_unit}")
        print(
            f"Selected {len(valid_pltdirs)} snapshots based on time interval of {args.time_interval}")
    print(f"Valid pltdirs: {valid_pltdirs}")

    # Parse center. NOTE(review): only numbers are parsed; a unit suffix such
    # as "1.0_kpc" raises ValueError here.
    if args.center is not None:
        if ',' in args.center:
            args.center = tuple(float(x) for x in args.center.split(','))
        elif '_' in args.center:
            args.center = tuple(float(x) for x in args.center.split('_'))
        else:
            args.center = float(args.center)
    else:
        args.center = 'c'

    # Nothing to plot: without this, --first_only would raise IndexError.
    if not valid_pltdirs:
        print("No valid plotfile directories to process.")
        return

    if args.first_only:
        valid_pltdirs = valid_pltdirs[:1]

    if args.n_processes == 1 or args.first_only:
        for pltdir in valid_pltdirs:
            plot_one(pltdir, args)
    else:
        plot_args = [(pltdir, args) for pltdir in valid_pltdirs]

        print(
            f"Processing {len(valid_pltdirs)} directories using {args.n_processes} processes")

        # Each worker process plots its own snapshots
        with Pool(processes=args.n_processes) as pool:
            pool.starmap(plot_one, plot_args)


def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Returns
    -------
    argparse.Namespace
        One attribute per option below; ``pltdirs`` is a non-empty list.
    """
    parser = argparse.ArgumentParser(
        description="Quick slice/projection plots of QUOKKA (AMReX/BoxLib) plotfiles using yt.")
    # plotfile directories, require at least one
    parser.add_argument("pltdirs", type=str, nargs="+",
                        help="Plotfile directories, require at least one. Use wildcards like plt00* to select multiple directories. Files with .old. in the name will be ignored.")
    # task: not read by this script (--kind chooses the plot type)
    parser.add_argument("--task", type=str, default="slc",
                        help="Unused; kept for backward compatibility. Use --kind to choose slc or proj. Default: slc")
    # pick field to plot
    parser.add_argument("-f", "--field", type=str, default="density",
                        help="Field to plot. Options: density, n, nH, T, vx, vy, vz, v, p, or any raw boxlib field name. Default: density")
    # kind of plot: slc or proj
    parser.add_argument("--kind", type=str, default="slc",
                        choices=["slc", "slice", "prj", "proj", "projection"],
                        help="Kind of plot: slc (slice) or proj (density-weighted projection). Default: slc")
    # width, optional
    parser.add_argument("-w", "--width", type=str, default=None,
                        help="width of the plot, e.g. 1 or 10_kpc. Default: None (full width)")
    # output directory
    parser.add_argument("-o", "--outdir", type=str, default=".",
                        help="Directory to save the figures. Default: .")
    # color limits: two values, each a number or min/max
    parser.add_argument("--zlim", type=str, nargs=2, default=None, metavar=("ZMIN", "ZMAX"),
                        help="Color limits of the plot, e.g. 1e3 1e6; use 'min' or 'max' for an automatic limit. Default: None (automatic)")
    # view direction
    parser.add_argument("--dir", type=str, default="z", choices=["x", "y", "z"],
                        help="view direction: x, y, or z. Default: z")
    # plot center
    parser.add_argument("--center", type=str, default=None,
                        help="Plot center as comma- or underscore-separated numbers, e.g. 1e10,1e10,1e10. Default: None (domain center)")
    # particle type to annotate
    parser.add_argument("--particles", type=str, nargs="+", default=[],
                        help="Particle type to annotate, e.g. CIC_particles, StochasticStellarPop_particles. Default: [] (no particles)")
    # toggle annotate grid lines
    parser.add_argument("--grids", action="store_true",
                        help="Annotate grid lines. Default: False")
    # toggle annotate cell edges
    parser.add_argument("--cell_edges", action="store_true",
                        help="Annotate cell edges. Default: False")
    # toggle annotate timestamp
    parser.add_argument("--timeoff", action="store_true",
                        help="Do not annotate timestamp. Default: False")
    # skip existing figures
    parser.add_argument("--skip_existing", action="store_true",
                        help="Skip existing figures. Default: False")
    # number of processes to use
    parser.add_argument("-j", "--n_processes", type=int,
                        default=1, help="Number of processes to use. Default: 1")
    # print field list
    parser.add_argument("--print_field_list", action="store_true",
                        help="Print the list of derived fields and stop. Default: False")
    # plot the first one only
    parser.add_argument("--first_only", action="store_true",
                        help="Plot only the first snapshot. Default: False")
    # text to annotate at top-left corner
    parser.add_argument("--top_left_text", type=str, default=None,
                        help="Text to annotate at the top-left corner in white")
    # text to annotate at top-right corner
    parser.add_argument("--top_right_text", type=str, default=None,
                        help="Text to annotate at the top-right corner in white")
    # mean molecular weight used for derived temperature / number density
    parser.add_argument("--mean_molecular_weight", type=float, default=1.0,
                        help="Mean molecular weight in atomic mass units. Default: 1.0")
    # time interval between selected snapshots
    parser.add_argument("--time_interval", type=str, default=None,
                        help="Time interval between snapshots. e.g. 1 or 0.1_Myr. Default: None (no filtering)")
    # figure size (in inches)
    parser.add_argument("--figsize", type=float, default=6,
                        help="Figure size in inches. Default: 6")
    # cmap
    parser.add_argument("--cmap", type=str, default="default",
                        help="Colormap to use, e.g. 'viridis', 'hot', 'jet'. Default: hot for temperature, viridis for everything else")
    # particle size
    parser.add_argument("--p_size", type=float, default=160.,
                        help="Size of particles in the plot. Default: 160.0")
    # particle marker
    parser.add_argument("--p_marker", type=str, default='.',
                        help="Marker style for particles. Common options: '.', '*', 'o', '+', 'x'. Default: '.'")
    # particle color
    parser.add_argument("--p_color", type=str, default=None,
                        help="Color for particles. Default: None (use default colors)")
    # annotate center
    parser.add_argument("--annotate_center", type=str, default=None, help="Text to annotate at the center of the domain")
    # axis units
    parser.add_argument("--axis_unit", type=str, default=None, help="Units for x and y axis, e.g. 'kpc', 'cm', 'pc'. Default: None (use dataset units)")
    parser.add_argument("--hide_axes", action="store_true", help="Hide the axes (the plot frame is kept)")
    parser.add_argument("--hide_all", action="store_true", help="Hide the axes, the colorbar and the colorbar label")

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the command line and run the plotting workflow.
    main(parse_args())
